Feat/simple mse metric #388 #425
base: main
Conversation
begumcig left a comment
Wow, this is already almost flawless. I asked for some small changes, but it is almost ready to be merged. Thanks a lot @AnikethBhosale
return

# Ensure tensors are on the same device
output_tensor = output_tensor.to(gt_tensor.device)
This is a great idea; that's why we have integrated device casting in the metric_data_processor. How do you feel about passing the device to it instead?
        The model predictions/outputs.
    """
    # Process inputs based on call_type (returns a tuple of tensors)
    inputs = metric_data_processor(x, gt, outputs, self.call_type)
You can pass the device here (regarding the device-casting comment above).
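A minimal sketch of the suggested change, assuming metric_data_processor accepts an optional device argument that moves the returned tensors onto it (its actual signature in Pruna may differ):

```python
# Hypothetical: assumes metric_data_processor takes a `device` keyword
# and casts all returned tensors onto that device internally.
inputs = metric_data_processor(x, gt, outputs, self.call_type, device=self.device)
```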
@@ -0,0 +1,247 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
I really like the variety in the tests! How do you feel about also testing with some data from pruna, similar to what we have in tests/evaluation/test_torch_metrics.py?
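A rough sketch of what such a test could look like; the import path, class name, and the update/compute interface are assumptions rather than Pruna's actual API:

```python
import torch


def test_mse_matches_hand_computed_value():
    # Hypothetical import; the real module path and class name in Pruna may differ.
    from pruna.evaluation.metrics import MSEMetric

    metric = MSEMetric()
    preds = torch.tensor([1.0, 2.0, 3.0])
    target = torch.tensor([1.0, 2.0, 5.0])
    metric.update(preds, target)

    # Squared errors are (0, 0, 4), so the mean is 4/3.
    expected = torch.tensor(4.0 / 3.0)
    assert torch.isclose(metric.compute(), expected)
```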
@@ -0,0 +1,200 @@
# MSE Metric Implementation Summary
Thank you a lot for this detailed summary! Are we planning to merge it into Pruna, or is it more for information? I think this would be even more beneficial as the PR description.
Description
Implements an MSE (Mean Squared Error) metric for Pruna's evaluation framework. The metric computes the mean squared error between model predictions and ground-truth values, accumulating results across batches using the StatefulMetric pattern.
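A minimal sketch of the stateful-accumulation idea described above, assuming a torchmetrics-style update/compute interface; the class and attribute names are illustrative, not Pruna's actual StatefulMetric API:

```python
import torch


class SimpleMSEMetric:
    """Illustrative stateful MSE: accumulates squared error across batches."""

    def __init__(self) -> None:
        self.sum_squared_error = torch.tensor(0.0)
        self.num_elements = 0

    def update(self, outputs: torch.Tensor, gt: torch.Tensor) -> None:
        # Move predictions (and the running sum) onto the ground-truth device
        # before accumulating, mirroring the device casting discussed above.
        outputs = outputs.to(gt.device)
        self.sum_squared_error = self.sum_squared_error.to(gt.device)
        self.sum_squared_error += torch.sum((outputs.float() - gt.float()) ** 2)
        self.num_elements += gt.numel()

    def compute(self) -> torch.Tensor:
        # Mean over every element seen so far, across all batches.
        return self.sum_squared_error / self.num_elements
```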
Related Issue
Fixes #388
Type of Change
How Has This Been Tested?
Checklist
Additional Notes
Task(metrics=["mse"])
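A hedged usage sketch of the snippet above; the import path is an assumption based on Pruna's documented evaluation flow:

```python
from pruna.evaluation.task import Task  # assumed import path

# Request the new metric by name, as in the note above.
task = Task(metrics=["mse"])
```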